{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Tutorial 1: Basics of CrypTen Tensors\n",
    "\n",
    "We now have a high-level understanding of how secure MPC works. Through these tutorials, we will explain how to use CrypTen to carry out secure operations on encrypted tensors. In this tutorial, we will introduce a fundamental building block in CrypTen, called a ```CrypTensor```.  ```CrypTensors``` are encrypted ```torch``` tensors that can be used for computing securely on data. \n",
    "\n",
    "CrypTen currently only supports secure MPC protocols (though we intend to add support for other advanced encryption protocols). Using the ```mpc``` backend, ```CrypTensors``` act as ```torch``` tensors whose values are encrypted using secure MPC protocols. Tensors created using the ```mpc``` backend are called ```MPCTensors```. We will go into greater detail about ```MPCTensors``` in Tutorial 2. \n",
    "\n",
    "Let's begin by importing ```crypten``` and ```torch``` libraries. (If the imports fail, please see the installation instructions in the README.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import crypten\n",
    "import torch\n",
    "\n",
    "crypten.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating Encrypted Tensors\n",
    "CrypTen provides a ```crypten.cryptensor``` factory function, similar to ```torch.tensor```, to make creating ```CrypTensors``` easy. \n",
    "\n",
    "Let's begin by creating a ```torch``` tensor and encrypting it using ```crypten.cryptensor```. To decrypt a ```CrypTensor```, use ```get_plain_text()``` to return the original tensor.  (```CrypTensors``` can also be created directly from a list or an array.)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'crypten' has no attribute 'print'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-93121eb73a29>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;31m# Decrypt x\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mx_dec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx_enc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_plain_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mcrypten\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_dec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: module 'crypten' has no attribute 'print'"
     ]
    }
   ],
   "source": [
    "# Create torch tensor\n",
    "x = torch.tensor([1.0, 2.0, 3.0])\n",
    "\n",
    "# Encrypt x\n",
    "x_enc = crypten.cryptensor(x)\n",
    "\n",
    "# Decrypt x\n",
    "x_dec = x_enc.get_plain_text()   \n",
    "crypten.print(x_dec)\n",
    "\n",
    "\n",
    "# Create python list\n",
    "y = [4.0, 5.0, 6.0]\n",
    "\n",
    "# Encrypt x\n",
    "y_enc = crypten.cryptensor(y)\n",
    "\n",
    "# Decrypt x\n",
    "y_dec = y_enc.get_plain_text()\n",
    "crypten.print(y_dec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Operations on Encrypted Tensors\n",
    "Now let's look at what we can do with our ```CrypTensors```.\n",
    "\n",
    "#### Arithmetic Operations\n",
    "We can carry out regular arithmetic operations between ```CrypTensors```, as well as between ```CrypTensors``` and plaintext tensors. Note that these operations never reveal any information about encrypted tensors (internally or externally) and return an encrypted tensor output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Public  addition: tensor([3., 4., 5.])\n",
      "Private addition: tensor([3., 4., 5.])\n",
      "\n",
      "Public  subtraction: tensor([-1.,  0.,  1.])\n",
      "Private subtraction: tensor([-1.,  0.,  1.])\n",
      "\n",
      "Public  multiplication: tensor([2., 4., 6.])\n",
      "Private multiplication: tensor([2., 4., 6.])\n",
      "\n",
      "Public  division: tensor([0.5000, 1.0000, 1.5000])\n",
      "Private division: tensor([0.5000, 1.0000, 1.5000])\n"
     ]
    }
   ],
   "source": [
    "#Arithmetic operations between CrypTensors and plaintext tensors\n",
    "x_enc = crypten.cryptensor([1.0, 2.0, 3.0])\n",
    "\n",
    "y = 2.0\n",
    "y_enc = crypten.cryptensor(2.0)\n",
    "\n",
    "\n",
    "# Addition\n",
    "z_enc1 = x_enc + y      # Public\n",
    "z_enc2 = x_enc + y_enc  # Private\n",
    "crypten.print(\"\\nPublic  addition:\", z_enc1.get_plain_text())\n",
    "crypten.print(\"Private addition:\", z_enc2.get_plain_text())\n",
    "\n",
    "\n",
    "# Subtraction\n",
    "z_enc1 = x_enc - y      # Public\n",
    "z_enc2 = x_enc - y_enc  # Private\n",
    "crypten.print(\"\\nPublic  subtraction:\", z_enc1.get_plain_text())\n",
    "print(\"Private subtraction:\", z_enc2.get_plain_text())\n",
    "\n",
    "# Multiplication\n",
    "z_enc1 = x_enc * y      # Public\n",
    "z_enc2 = x_enc * y_enc  # Private\n",
    "print(\"\\nPublic  multiplication:\", z_enc1.get_plain_text())\n",
    "print(\"Private multiplication:\", z_enc2.get_plain_text())\n",
    "\n",
    "# Division\n",
    "z_enc1 = x_enc / y      # Public\n",
    "z_enc2 = x_enc / y_enc  # Private\n",
    "print(\"\\nPublic  division:\", z_enc1.get_plain_text())\n",
    "print(\"Private division:\", z_enc2.get_plain_text())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Comparisons\n",
    "Similarly, we can compute element-wise comparisons on ```CrypTensors```. Like arithmetic operations, comparisons performed on ```CrypTensor```s will return a ```CrypTensor``` result. Decrypting these result ```CrypTensor```s will evaluate to 0's and 1's corresponding to ```False``` and ```True``` values respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x:  tensor([1., 2., 3., 4., 5.])\n",
      "y:  tensor([5., 4., 3., 2., 1.])\n",
      "\n",
      "Public  (x < y) : tensor([1., 1., 0., 0., 0.])\n",
      "Private (x < y) : tensor([1., 1., 0., 0., 0.])\n",
      "\n",
      "Public  (x <= y): tensor([1., 1., 1., 0., 0.])\n",
      "Private (x <= y): tensor([1., 1., 1., 0., 0.])\n",
      "\n",
      "Public  (x > y) : tensor([0., 0., 0., 1., 1.])\n",
      "Private (x > y) : tensor([0., 0., 0., 1., 1.])\n",
      "\n",
      "Public  (x >= y): tensor([0., 0., 1., 1., 1.])\n",
      "Private (x >= y): tensor([0., 0., 1., 1., 1.])\n",
      "\n",
      "Public  (x == y): tensor([0., 0., 1., 0., 0.])\n",
      "Private (x == y): tensor([0., 0., 1., 0., 0.])\n",
      "\n",
      "Public  (x != y): tensor([1., 1., 0., 1., 1.])\n",
      "Private (x != y): tensor([1., 1., 0., 1., 1.])\n"
     ]
    }
   ],
   "source": [
    "#Construct two example CrypTensors\n",
    "x_enc = crypten.cryptensor([1.0, 2.0, 3.0, 4.0, 5.0])\n",
    "\n",
    "y = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0])\n",
    "y_enc = crypten.cryptensor(y)\n",
    "\n",
    "# Print values:\n",
    "print(\"x: \", x_enc.get_plain_text())\n",
    "print(\"y: \", y_enc.get_plain_text())\n",
    "\n",
    "# Less than\n",
    "z_enc1 = x_enc < y      # Public\n",
    "z_enc2 = x_enc < y_enc  # Private\n",
    "print(\"\\nPublic  (x < y) :\", z_enc1.get_plain_text())\n",
    "print(\"Private (x < y) :\", z_enc2.get_plain_text())\n",
    "\n",
    "# Less than or equal\n",
    "z_enc1 = x_enc <= y      # Public\n",
    "z_enc2 = x_enc <= y_enc  # Private\n",
    "print(\"\\nPublic  (x <= y):\", z_enc1.get_plain_text())\n",
    "print(\"Private (x <= y):\", z_enc2.get_plain_text())\n",
    "\n",
    "# Greater than\n",
    "z_enc1 = x_enc > y      # Public\n",
    "z_enc2 = x_enc > y_enc  # Private\n",
    "print(\"\\nPublic  (x > y) :\", z_enc1.get_plain_text())\n",
    "print(\"Private (x > y) :\", z_enc2.get_plain_text())\n",
    "\n",
    "# Greater than or equal\n",
    "z_enc1 = x_enc >= y      # Public\n",
    "z_enc2 = x_enc >= y_enc  # Private\n",
    "print(\"\\nPublic  (x >= y):\", z_enc1.get_plain_text())\n",
    "print(\"Private (x >= y):\", z_enc2.get_plain_text())\n",
    "\n",
    "# Equal\n",
    "z_enc1 = x_enc == y      # Public\n",
    "z_enc2 = x_enc == y_enc  # Private\n",
    "print(\"\\nPublic  (x == y):\", z_enc1.get_plain_text())\n",
    "print(\"Private (x == y):\", z_enc2.get_plain_text())\n",
    "\n",
    "# Not Equal\n",
    "z_enc1 = x_enc != y      # Public\n",
    "z_enc2 = x_enc != y_enc  # Private\n",
    "print(\"\\nPublic  (x != y):\", z_enc1.get_plain_text())\n",
    "print(\"Private (x != y):\", z_enc2.get_plain_text())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Advanced mathematics\n",
    "We are also able to compute more advanced mathematical functions on ```CrypTensors``` using iterative approximations. CrypTen provides MPC support for functions like reciprocal, exponential, logarithm, square root, tanh, etc. Notice that these are subject to numerical error due to the approximations used. \n",
    "\n",
    "Additionally, note that some of these functions will fail silently when input values are outside of the range of convergence for the approximations used. These do not produce errors because value are encrypted and cannot be checked without decryption. Exercise caution when using these functions. (It is good practice here to normalize input values for certain models.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Public  reciprocal: tensor([10.0000,  3.3333,  2.0000,  1.0000,  0.6667,  0.5000,  0.4000])\n",
      "Private reciprocal: tensor([10.0009,  3.3335,  2.0000,  1.0000,  0.6667,  0.5000,  0.4000])\n",
      "\n",
      "Public  logarithm: tensor([-2.3026, -1.2040, -0.6931,  0.0000,  0.4055,  0.6931,  0.9163])\n",
      "Private logarithm: tensor([    -2.3181,     -1.2110,     -0.6997,      0.0004,      0.4038,\n",
      "             0.6918,      0.9150])\n",
      "\n",
      "Public  exponential: tensor([ 1.1052,  1.3499,  1.6487,  2.7183,  4.4817,  7.3891, 12.1825])\n",
      "Private exponential: tensor([ 1.1021,  1.3440,  1.6468,  2.7121,  4.4574,  7.3280, 12.0188])\n",
      "\n",
      "Public  square root: tensor([0.3162, 0.5477, 0.7071, 1.0000, 1.2247, 1.4142, 1.5811])\n",
      "Private square root: tensor([0.3147, 0.5477, 0.7071, 0.9989, 1.2237, 1.4141, 1.5811])\n",
      "\n",
      "Public  tanh: tensor([0.0997, 0.2913, 0.4621, 0.7616, 0.9051, 0.9640, 0.9866])\n",
      "Private tanh: tensor([0.0994, 0.2914, 0.4636, 0.7636, 0.9069, 0.9652, 0.9873])\n"
     ]
    }
   ],
   "source": [
    "torch.set_printoptions(sci_mode=False)\n",
    "\n",
    "#Construct example input CrypTensor\n",
    "x = torch.tensor([0.1, 0.3, 0.5, 1.0, 1.5, 2.0, 2.5])\n",
    "x_enc = crypten.cryptensor(x)\n",
    "\n",
    "# Reciprocal\n",
    "z = x.reciprocal()          # Public\n",
    "z_enc = x_enc.reciprocal()  # Private\n",
    "print(\"\\nPublic  reciprocal:\", z)\n",
    "print(\"Private reciprocal:\", z_enc.get_plain_text())\n",
    "\n",
    "# Logarithm\n",
    "z = x.log()          # Public\n",
    "z_enc = x_enc.log()  # Private\n",
    "print(\"\\nPublic  logarithm:\", z)\n",
    "print(\"Private logarithm:\", z_enc.get_plain_text())\n",
    "\n",
    "# Exp\n",
    "z = x.exp()          # Public\n",
    "z_enc = x_enc.exp()  # Private\n",
    "print(\"\\nPublic  exponential:\", z)\n",
    "print(\"Private exponential:\", z_enc.get_plain_text())\n",
    "\n",
    "# Sqrt\n",
    "z = x.sqrt()          # Public\n",
    "z_enc = x_enc.sqrt()  # Private\n",
    "print(\"\\nPublic  square root:\", z)\n",
    "print(\"Private square root:\", z_enc.get_plain_text())\n",
    "\n",
    "# Tanh\n",
    "z = x.tanh()          # Public\n",
    "z_enc = x_enc.tanh()  # Private\n",
    "print(\"\\nPublic  tanh:\", z)\n",
    "print(\"Private tanh:\", z_enc.get_plain_text())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Control Flow using Encrypted Tensors\n",
    "\n",
    "Note that ```CrypTensors``` cannot be used directly in conditional expressions. Because the tensor is encrypted, the boolean expression cannot be evaluated unless the tensor is decrypted first. Attempting to execute control flow using an encrypted condition will result in an error.\n",
    "\n",
    "Some control flow can still be executed without decrypting, but must be executed using mathematical expressions. We have provided the function ```crypten.where(condition, x, y)``` to abstract this kind of conditional value setting.\n",
    "\n",
    "The following example illustrates how to write this kind conditional logic for ```CrypTensors```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RuntimeError caught: \"Cannot evaluate MPCTensors to boolean values\"\n",
      "\n",
      "z: tensor(2.)\n",
      "z: tensor(2.)\n"
     ]
    }
   ],
   "source": [
    "x_enc = crypten.cryptensor(2.0)\n",
    "y_enc = crypten.cryptensor(4.0)\n",
    "\n",
    "a, b = 2, 3\n",
    "\n",
    "# Normal Control-flow code will raise an error\n",
    "try:\n",
    "    if x_enc < y_enc:\n",
    "        z = a\n",
    "    else:\n",
    "        z = b\n",
    "except RuntimeError as error:\n",
    "    print(f\"RuntimeError caught: \\\"{error}\\\"\\n\")\n",
    "\n",
    "    \n",
    "# Instead use a mathematical expression\n",
    "use_a = (x_enc < y_enc)\n",
    "z_enc = use_a * a + (1 - use_a) * b\n",
    "print(\"z:\", z_enc.get_plain_text())\n",
    "    \n",
    "    \n",
    "# Or use the `where` function\n",
    "z_enc = crypten.where(x_enc < y_enc, a, b)\n",
    "print(\"z:\", z_enc.get_plain_text())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Advanced Indexing\n",
    "CrypTen supports many of the operations that work on ```torch``` tensors. Encrypted tensors can be indexed, concatenated, stacked, reshaped, etc. For a full list of operations, see the CrypTen documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Indexing:\n",
      " tensor([1., 2.])\n",
      "\n",
      "Concatenation:\n",
      " tensor([1., 2., 3., 4., 5., 6.])\n",
      "\n",
      "Stacking:\n",
      " tensor([[1., 2., 3.],\n",
      "        [4., 5., 6.]])\n",
      "\n",
      "Reshaping:\n",
      " tensor([[1., 2., 3., 4., 5., 6.]])\n"
     ]
    }
   ],
   "source": [
    "x_enc = crypten.cryptensor([1.0, 2.0, 3.0])\n",
    "y_enc = crypten.cryptensor([4.0, 5.0, 6.0])\n",
    "\n",
    "# Indexing\n",
    "z_enc = x_enc[:-1]\n",
    "print(\"Indexing:\\n\", z_enc.get_plain_text())\n",
    "\n",
    "# Concatenation\n",
    "z_enc = crypten.cat([x_enc, y_enc])\n",
    "print(\"\\nConcatenation:\\n\", z_enc.get_plain_text())\n",
    "\n",
    "# Stacking\n",
    "z_enc = crypten.stack([x_enc, y_enc])\n",
    "print('\\nStacking:\\n', z_enc.get_plain_text())\n",
    "\n",
    "# Reshaping\n",
    "w_enc = z_enc.reshape(-1, 6)\n",
    "print('\\nReshaping:\\n', w_enc.get_plain_text())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Implementation Note\n",
    "\n",
    "Due to internal implementation details, ```CrypTensors``` must be the first operand of operations that combine ```CrypTensors``` and ```torch``` tensors. That is, for a ```CrypTensor``` ```x_enc``` and a plaintext tensor ```y```:\n",
    "- The expression ```x_enc < y``` is valid, but the equivalent expression ```y > x_enc``` will result in an error.\n",
    "- The expression ```x_enc + y``` is valid, but the equivalent expression ```y + x_enc``` will result in an error.\n",
    "\n",
    "We intend to add support for both expressions in the future."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
